[PATCH] [mlir][spirv] Account for type conversion failures in scf-to-spirv
authorJakub Kuderski <kubak@google.com>
Mon, 9 Jan 2023 16:35:46 +0000 (11:35 -0500)
committerGianfranco Costamagna <locutusofborg@debian.org>
Thu, 7 Sep 2023 22:43:45 +0000 (00:43 +0200)
Fixes: https://github.com/llvm/llvm-project/issues/59136
Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D141292

Gbp-Pq: Name CVE-2023-29934.patch

mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
mlir/test/Conversion/SCFToSPIRV/if.mlir

index 10623b63d949d26f3dddd45f6a8503fa81a6514c..81c521d9da59df132d5d919f42d284590514a657 100644 (file)
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
 
@@ -286,6 +287,10 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
   SmallVector<Type, 8> returnTypes;
   for (auto result : ifOp.getResults()) {
     auto convertedType = typeConverter.convertType(result.getType());
+    if (!convertedType)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("failed to convert type '{0}'", result.getType()));
+
     returnTypes.push_back(convertedType);
   }
   replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
index f937ac6c4e06fb7fd66f21508f0fe93ee1d7bf19..79f53aaa8f2eb01d8f8951d7b3fd03c014099938 100644 (file)
@@ -153,4 +153,18 @@ func.func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10
   return
 }
 
+// Memrefs without a spirv storage class are not supported. The conversion
+// should preserve the `scf.if` and not crash.
+func.func @unsupported_yield_type(%arg0 : memref<8xi32>, %arg1 : memref<8xi32>, %c : i1) {
+// CHECK-LABEL: @unsupported_yield_type
+// CHECK-NEXT:    scf.if
+// CHECK:         spirv.Return
+  %r = scf.if %c -> (memref<8xi32>) {
+    scf.yield %arg0 : memref<8xi32>
+  } else {
+    scf.yield %arg1 : memref<8xi32>
+  }
+  return
+}
+
 } // end module